{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Chapter 7: Support Vector Machines"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "----\n",
    "Separating hyperplane: $w^Tx+b=0$\n",
    "\n",
    "Distance from a point $x$ to the hyperplane: $r=\\frac{|w^Tx+b|}{||w||_2}$\n",
    "\n",
    "where $||w||_2$ is the L2 norm of $w\\in\\mathbb{R}^n$: $||w||_2=\\sqrt{\\sum^n_{i=1}w_i^2}$\n",
    "\n",
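    "A minimal NumPy sketch of the distance formula. The hyperplane $w=(3,4)^T$, $b=-5$ and the point $x=(3,4)^T$ are made-up values for illustration, not taken from the text.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical hyperplane w^T x + b = 0 and a sample point (illustrative values)\n",
    "w = np.array([3.0, 4.0])\n",
    "b = -5.0\n",
    "x = np.array([3.0, 4.0])\n",
    "\n",
    "# r = |w^T x + b| / ||w||_2\n",
    "r = abs(np.dot(w, x) + b) / np.linalg.norm(w, 2)\n",
    "print(r)  # (25 - 5) / 5 = 4.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [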
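    "In two dimensions the hyperplane is a line. If the samples are linearly separable, $w$ and $b$ can be rescaled so that every sample satisfies:\n",
    "\n",
    "$w^Tx_i+b\\geq+1$ if $y_i=+1$\n",
    "\n",
    "$w^Tx_i+b\\leq-1$ if $y_i=-1$\n",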
    "\n",
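    "#### Margin\n",
    "\n",
    "**Functional margin**: $y_i(w^Tx_i+b)$, where $y_i\\in\\{-1,+1\\}$ is the label of sample $x_i$\n",
    "\n",
    "**Geometric margin**: $r_i=\\frac{y_i(w^Tx_i+b)}{||w||_2}$. When a sample is classified correctly, its geometric margin equals its distance to the hyperplane.\n",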
    "\n",
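    "A minimal sketch computing both margins for a made-up hyperplane and three made-up samples. It also checks the property used below: rescaling $w$ and $b$ by the same factor scales the functional margin but leaves the geometric margin unchanged.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical hyperplane and labelled samples (illustrative values)\n",
    "w = np.array([1.0, -1.0])\n",
    "b = 0.5\n",
    "X = np.array([[2.0, 0.0], [0.0, 2.0], [3.0, 1.0]])\n",
    "y = np.array([1.0, -1.0, 1.0])\n",
    "\n",
    "def margins(w, b, X, y):\n",
    "    functional = y * (X @ w + b)                # y_i (w^T x_i + b)\n",
    "    geometric = functional / np.linalg.norm(w)  # divided by ||w||_2\n",
    "    return functional, geometric\n",
    "\n",
    "f1, g1 = margins(w, b, X, y)\n",
    "f2, g2 = margins(10 * w, 10 * b, X, y)  # same hyperplane, parameters scaled by 10\n",
    "print(f1, f2)               # functional margins grow by a factor of 10\n",
    "print(np.allclose(g1, g2))  # geometric margins are unchanged: True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [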
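    "To maximize the geometric margin, the basic SVM problem can be written as follows, where $r^*$ is the functional margin of the samples closest to the hyperplane and $\\frac{r^*}{||w||}$ is the corresponding geometric margin:\n",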
    "\n",
    "$$\\max\\ \\frac{r^*}{||w||}$$\n",
    "\n",
    "$$s.t.\\ y_i(w^Tx_i+b)\\geq r^*,\\ i=1,2,\\dots,m$$\n",
    "\n",
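    "That is, every sample is classified correctly while the smallest geometric margin is as large as possible. This objective is not convex, so we (1) rewrite it as an equivalent convex problem and (2) solve its dual with Lagrange multipliers and the KKT conditions.\n",
    "\n",
    "(1) Convert to a convex problem:\n",
    "\n",
    "First set $r^*=1$. This loses no generality: scaling $w$ and $b$ by the same positive factor scales $r^*$ by that factor but leaves the hyperplane and the geometric margin unchanged.\n",
    "\n",
    "$$\\max\\ \\frac{1}{||w||}$$\n",
    "\n",
    "$$s.t.\\ y_i(w^Tx_i+b)\\geq 1,\\ i=1,2,\\dots,m$$\n",
    "\n",
    "Maximizing $\\frac{1}{||w||}$ is equivalent to minimizing $\\frac{1}{2}||w||^2$, which is a convex quadratic objective; the factor $\\frac{1}{2}$ only simplifies the derivative.\n",
    "\n",
    "$$\\min\\ \\frac{1}{2}||w||^2$$\n",
    "\n",
    "$$s.t.\\ y_i(w^Tx_i+b)\\geq 1,\\ i=1,2,\\dots,m$$\n",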
    "\n",
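    "A sketch that solves this convex quadratic program numerically with SciPy's general-purpose SLSQP solver (swapped in for illustration; the notes do not prescribe a solver). The two-point dataset is made up: $(2,0)^T$ with label $+1$ and $(0,0)^T$ with label $-1$. The maximum-margin boundary should be the vertical line through $(1,0)^T$, i.e. $w=(1,0)^T$, $b=-1$.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.optimize import minimize\n",
    "\n",
    "# Made-up linearly separable data: (2, 0) with y = +1, (0, 0) with y = -1\n",
    "X = np.array([[2.0, 0.0], [0.0, 0.0]])\n",
    "y = np.array([1.0, -1.0])\n",
    "\n",
    "def objective(params):\n",
    "    w = params[:2]\n",
    "    return 0.5 * np.dot(w, w)  # (1/2)||w||^2; params = (w_1, w_2, b)\n",
    "\n",
    "def margin_constraints(params):\n",
    "    w, b = params[:2], params[2]\n",
    "    return y * (X @ w + b) - 1.0  # SLSQP 'ineq' constraints mean >= 0\n",
    "\n",
    "res = minimize(objective, x0=np.zeros(3), method='SLSQP',\n",
    "               constraints=[{'type': 'ineq', 'fun': margin_constraints}])\n",
    "w_opt, b_opt = res.x[:2], res.x[2]\n",
    "print(w_opt, b_opt)  # approximately [1, 0] and -1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [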
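    "(2) Solve it with Lagrange multipliers and the KKT conditions. First write each constraint in the form $g_i(w,b)\\leq 0$:\n",
    "\n",
    "$$\\min\\ \\frac{1}{2}||w||^2$$\n",
    "\n",
    "$$s.t.\\ -y_i(w^Tx_i+b)+1\\leq 0,\\ i=1,2,\\dots,m$$\n",
    "\n",
    "Introducing a multiplier $\\alpha_i\\geq 0$ for each constraint gives the Lagrangian:\n",
    "\n",
    "$$L(w, b, \\alpha) = \\frac{1}{2}||w||^2+\\sum^m_{i=1}\\alpha_i(-y_i(w^Tx_i+b)+1)$$\n",
    "\n",
    "The primal problem equals $\\min_{w,b}\\max_{\\alpha\\geq 0} L(w, b, \\alpha)$, and weak duality gives $\\min_{w,b}\\max_{\\alpha\\geq 0} L(w, b, \\alpha)\\geq \\max_{\\alpha\\geq 0}\\min_{w,b} L(w, b, \\alpha)$. Because the objective is convex and the constraints are linear, equality holds (strong duality) whenever the data is linearly separable, so we can solve the dual $\\max_{\\alpha\\geq 0}\\min_{w,b} L$ instead.\n",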
    "\n",
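    "Setting the gradients with respect to $w$ and $b$ to zero (the stationarity part of the KKT conditions) solves the inner minimization:\n",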
    "\n",
    "$$\\frac{\\partial }{\\partial w}L(w, b, \\alpha)=w-\\sum\\alpha_iy_ix_i=0,\\ w=\\sum\\alpha_iy_ix_i$$\n",
    "\n",
    "$$\\frac{\\partial }{\\partial b}L(w, b, \\alpha)=\\sum\\alpha_iy_i=0$$\n",
    "\n",
    "Substituting these back into $L(w, b, \\alpha)$:\n",
    "\n",
    "$\\min_{w,b}\\ L(w, b, \\alpha)=\\frac{1}{2}||w||^2+\\sum^m_{i=1}\\alpha_i(-y_i(w^Tx_i+b)+1)$\n",
    "\n",
    "$\\qquad\\qquad\\qquad=\\frac{1}{2}w^Tw-\\sum^m_{i=1}\\alpha_iy_iw^Tx_i-b\\sum^m_{i=1}\\alpha_iy_i+\\sum^m_{i=1}\\alpha_i$\n",
    "\n",
    "$\\qquad\\qquad\\qquad=\\frac{1}{2}w^T\\sum\\alpha_iy_ix_i-\\sum^m_{i=1}\\alpha_iy_iw^Tx_i+\\sum^m_{i=1}\\alpha_i$\n",
    "\n",
    "$\\qquad\\qquad\\qquad=\\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i=1}\\alpha_iy_iw^Tx_i$\n",
    "\n",
    "$\\qquad\\qquad\\qquad=\\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_jx_i^Tx_j$\n",
    "\n",
    "The dual problem maximizes this over $\\alpha$; flipping the sign turns it into an equivalent minimization:\n",
    "\n",
    "$\\max_{\\alpha}\\ \\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_jx_i^Tx_j\\iff\\min_{\\alpha}\\ \\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_jx_i^Tx_j-\\sum^m_{i=1}\\alpha_i$\n",
    "\n",
    "$s.t.\\ \\sum^m_{i=1}\\alpha_iy_i=0,$\n",
    "\n",
    "$\\alpha_i \\geq 0,\\ i=1,2,\\dots,m$\n",
    "\n",
    "This is the dual form of the (hard-margin) SVM problem.\n",
    "\n",
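    "A sketch that solves this dual numerically with SciPy's SLSQP solver (an illustrative choice) on a made-up two-point dataset, $(2,0)^T$ with label $+1$ and $(0,0)^T$ with label $-1$. It then recovers $w=\\sum_i\\alpha_iy_ix_i$ and $b$ from a support vector $x_s$ ($\\alpha_s>0$, so $y_s(w^Tx_s+b)=1$). For this data the expected result is $\\alpha=(0.5,0.5)$, $w=(1,0)^T$, $b=-1$.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.optimize import minimize\n",
    "\n",
    "# Made-up data: (2, 0) with y = +1, (0, 0) with y = -1\n",
    "X = np.array([[2.0, 0.0], [0.0, 0.0]])\n",
    "y = np.array([1.0, -1.0])\n",
    "Z = y[:, None] * X\n",
    "Q = Z @ Z.T  # Q_ij = y_i y_j x_i^T x_j\n",
    "\n",
    "def dual_objective(alpha):\n",
    "    # (1/2) sum_ij alpha_i alpha_j y_i y_j x_i^T x_j - sum_i alpha_i\n",
    "    return 0.5 * alpha @ Q @ alpha - alpha.sum()\n",
    "\n",
    "res = minimize(dual_objective, x0=np.zeros(2), method='SLSQP',\n",
    "               bounds=[(0, None)] * 2,                                # alpha_i >= 0\n",
    "               constraints=[{'type': 'eq', 'fun': lambda a: a @ y}])  # sum_i alpha_i y_i = 0\n",
    "alpha = res.x\n",
    "w = (alpha * y) @ X    # w = sum_i alpha_i y_i x_i\n",
    "s = np.argmax(alpha)   # index of a support vector (alpha_s > 0)\n",
    "b = y[s] - w @ X[s]    # from y_s (w^T x_s + b) = 1 with y_s in {-1, +1}\n",
    "print(alpha, w, b)     # approximately [0.5, 0.5], [1, 0], -1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [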
    "-----\n",
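    "#### Kernel\n",
    "\n",
    "A kernel computes inner products in a high-dimensional feature space directly from the low-dimensional inputs: $K(x_i,x_j)=\\phi(x_i)^T\\phi(x_j)$. The dual depends on the data only through the inner products $x_i^Tx_j$. Replacing them with $K(x_i,x_j)$ therefore trains a linear separator in the feature space, where data that is not linearly separable in the input space may become separable.\n",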
    "\n",
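    "A minimal sketch of this idea for the degree-2 polynomial kernel $K(x,z)=(x^Tz)^2$ on 2-D inputs (a kernel chosen for illustration). Its explicit feature map is $\\phi(x)=(x_1^2,\\sqrt{2}x_1x_2,x_2^2)^T$, where $x_1,x_2$ are the components of $x$. The kernel evaluated in the original 2-D space matches the inner product in the 3-D feature space.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def phi(x):\n",
    "    # Explicit feature map of the degree-2 kernel (x^T z)^2 for 2-D inputs\n",
    "    x1, x2 = x\n",
    "    return np.array([x1 * x1, np.sqrt(2) * x1 * x2, x2 * x2])\n",
    "\n",
    "def poly2_kernel(x, z):\n",
    "    # The same inner product, computed in the original 2-D space\n",
    "    return np.dot(x, z) ** 2\n",
    "\n",
    "x = np.array([1.0, 2.0])\n",
    "z = np.array([3.0, 0.5])\n",
    "print(poly2_kernel(x, z), np.dot(phi(x), phi(z)))  # both 16, up to floating-point rounding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [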
    "#### Soft margin & slack variables\n",
    "\n",
    "Introduce a slack variable $\\xi_i\\geq0$ for each sample: the amount by which sample $i$ may violate the functional-margin constraint.\n",
    "\n",
    "Objective: $\\min\\ \\frac{1}{2}||w||^2+C\\sum^m_{i=1}\\xi_i\\qquad s.t.\\ y_i(w^Tx_i+b)\\geq1-\\xi_i,\\ \\xi_i\\geq 0$\n",
    "\n",
    "Dual problem (the only change from the hard-margin dual is the upper bound $C$ on each $\\alpha_i$):\n",
    "\n",
    "$$\\max_{\\alpha}\\ \\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_jx_i^Tx_j\\iff\\min_{\\alpha}\\ \\frac{1}{2}\\sum^m_{i,j=1}\\alpha_i\\alpha_jy_iy_jx_i^Tx_j-\\sum^m_{i=1}\\alpha_i$$\n",
    "\n",
    "$$s.t.\\ 0\\leq\\alpha_i\\leq C,\\ i=1,2,\\dots,m,\\quad \\sum^m_{i=1}\\alpha_iy_i=0$$\n",
    "\n",
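    "A minimal sketch of the soft-margin objective for fixed, made-up parameters (not a fitted model). For given $w$ and $b$ the smallest feasible slack is $\\xi_i=\\max(0,1-y_i(w^Tx_i+b))$, the hinge loss. Of the three made-up samples, one lies outside the margin, one inside it, and one on the wrong side of the hyperplane.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical parameters and samples (illustrative values, not fitted)\n",
    "w = np.array([1.0, -1.0])\n",
    "b = 0.0\n",
    "C = 1.0\n",
    "X = np.array([[2.0, 0.0], [0.5, 0.0], [-1.0, 0.0]])\n",
    "y = np.array([1.0, 1.0, 1.0])\n",
    "\n",
    "# Smallest slack satisfying y_i (w^T x_i + b) >= 1 - xi_i and xi_i >= 0\n",
    "xi = np.maximum(0.0, 1.0 - y * (X @ w + b))\n",
    "objective = 0.5 * np.dot(w, w) + C * xi.sum()\n",
    "print(xi, objective)  # xi = [0, 0.5, 2]; objective = 0.5*2 + 1.0*2.5 = 3.5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [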
    "-----\n",
    "\n",
    "#### Sequential Minimal Optimization\n",
    "\n",
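    "First define the classifier's output for an input $x$: $u=w^Tx+b$.\n",
    "\n",
    "Since $w=\\sum^m_{i=1}\\alpha_iy_ix_i$, the output can be written with a kernel as\n",
    "\n",
    "$u=\\sum^m_{i=1}y_i\\alpha_iK(x_i, x)+b$\n",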
    "\n",
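    "A minimal sketch checking that, with the linear kernel $K(x_i,x)=x_i^Tx$, the kernel expansion $\\sum_i y_i\\alpha_iK(x_i,x)+b$ equals $w^Tx+b$ with $w=\\sum_i\\alpha_iy_ix_i$. The points, labels, multipliers and $b$ are made-up values.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Made-up training points, labels, multipliers and bias (illustrative values)\n",
    "X = np.array([[2.0, 0.0], [0.0, 0.0], [1.0, 3.0]])\n",
    "y = np.array([1.0, -1.0, 1.0])\n",
    "alpha = np.array([0.5, 0.5, 0.0])\n",
    "b = -1.0\n",
    "\n",
    "def linear_kernel(a, c):\n",
    "    return np.dot(a, c)  # K(a, c) = a^T c\n",
    "\n",
    "x = np.array([1.5, 2.0])\n",
    "# Kernel form: u = sum_i y_i alpha_i K(x_i, x) + b\n",
    "u_kernel = sum(y[i] * alpha[i] * linear_kernel(X[i], x) for i in range(len(X))) + b\n",
    "# Primal form: u = w^T x + b with w = sum_i alpha_i y_i x_i\n",
    "w = (alpha * y) @ X\n",
    "u_primal = np.dot(w, x) + b\n",
    "print(u_kernel, u_primal)  # both 0.5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [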
    "\n",
    "----\n",
    "\n",
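    "With a feature map $\\phi$ and kernel $K(x_i,x_j)=\\phi(x_i)^T\\phi(x_j)$, the soft-margin dual that SMO solves is:\n",
    "\n",
    "$\\max_{\\alpha}\\ \\sum^m_{i=1}\\alpha_i-\\frac{1}{2}\\sum^m_{i=1}\\sum^m_{j=1}\\alpha_i\\alpha_jy_iy_j\\phi(x_i)^T\\phi(x_j)$\n",
    "\n",
    "$s.t.\\ \\sum^m_{i=1}\\alpha_iy_i=0,$\n",
    "\n",
    "$0\\leq\\alpha_i\\leq C,\\ i=1,2,\\dots,m$\n",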
    "\n",
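    "A sketch of evaluating this dual objective for a given $\\alpha$ through the Gram matrix $K_{ij}=K(x_i,x_j)$. It checks the double sum against the equivalent matrix form $\\sum_i\\alpha_i-\\frac{1}{2}(\\alpha\\circ y)^TK(\\alpha\\circ y)$. The data, labels, $\\alpha$ and the RBF kernel with $\\gamma=0.5$ are illustrative assumptions.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Illustrative data, labels and multipliers (assumed values)\n",
    "X = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 0.0]])\n",
    "y = np.array([-1.0, 1.0, 1.0])\n",
    "alpha = np.array([0.4, 0.3, 0.1])  # satisfies sum_i alpha_i y_i = 0\n",
    "\n",
    "def rbf(a, c, gamma=0.5):\n",
    "    # Hypothetical helper: RBF kernel exp(-gamma * ||a - c||^2)\n",
    "    return np.exp(-gamma * np.sum((a - c) ** 2))\n",
    "\n",
    "m = len(X)\n",
    "K = np.array([[rbf(X[i], X[j]) for j in range(m)] for i in range(m)])  # Gram matrix\n",
    "\n",
    "# Double sum exactly as written in the formula\n",
    "W_loop = alpha.sum() - 0.5 * sum(alpha[i] * alpha[j] * y[i] * y[j] * K[i, j]\n",
    "                                 for i in range(m) for j in range(m))\n",
    "# Same value in matrix form\n",
    "ay = alpha * y\n",
    "W_vec = alpha.sum() - 0.5 * ay @ K @ ay\n",
    "print(W_loop, W_vec)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [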
    "Reference:  \n",
    "https://www.youtube.com/watch?v=_PwhiWxHK8o  \n",
    "https://www.youtube.com/watch?v=vywmP6Ud1HA  \n",
    "https://www.youtube.com/watch?v=iB2VK7qPfjg\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.model_selection import train_test_split\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# data: the first 100 iris samples (setosa and versicolor), sepal length and sepal width only\n",
    "def create_data():\n",
    "    iris = load_iris()\n",
    "    df = pd.DataFrame(iris.data, columns=iris.feature_names)\n",
    "    df['label'] = iris.target\n",
    "    df.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']\n",
    "    data = np.array(df.iloc[:100, [0, 1, -1]])\n",
    "    # SVM labels must be -1/+1: map class 0 (setosa) to -1, keep class 1 (versicolor) as +1\n",
    "    data[:, -1] = np.where(data[:, -1] == 0, -1, 1)\n",
    "    return data[:, :2], data[:, -1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y = create_data()\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 161,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.legend.Legend at 0x117f1f2b0>"
      ]
     },
     "execution_count": 161,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvDW2N/gAAGXxJREFUeJzt3X+MXWWdx/H3d4dZOiowaRkWmCmWVdM/bLsWRrBpQlxxF8VaGmShjb+qrN01uGBwMdYQ1IYEDQaV1WhayALCVrsVu4XlxyIs8UekyZTWdrWQoIt2CixDsa2shW3Ld/+4d+jM7Z2597n3nnuf57mfV9J07rkPp9/nHP329pzPea65OyIikpc/6XQBIiLSemruIiIZUnMXEcmQmruISIbU3EVEMqTmLiKSITV3EZEMqbmLiGRIzV1EJEPH1TvQzHqAEWCPuy+peG8lcCOwp7zpm+5+y3T7O/nkk33OnDlBxYqIdLutW7e+4O4DtcbV3dyBq4BdwIlTvP99d/9UvTubM2cOIyMjAX+8iIiY2W/rGVfXZRkzGwLeB0z7aVxEROJQ7zX3rwOfBV6dZswHzGyHmW00s9nVBpjZKjMbMbORsbGx0FpFRKRONZu7mS0Bnnf3rdMMuweY4+4LgB8Bt1cb5O5r3X3Y3YcHBmpeMhIRkQbVc819MbDUzC4EZgAnmtmd7v6h8QHuvnfC+HXAV1pbpohI4w4dOsTo6Cgvv/xyp0up24wZMxgaGqK3t7eh/75mc3f31cBqADN7J/CPExt7eftp7v5s+eVSSjdeRUSiMDo6ygknnMCcOXMws06XU5O7s3fvXkZHRznzzDMb2kfDOXczW2NmS8svrzSzX5rZL4ArgZWN7ldEpNVefvllZs2alURjBzAzZs2a1dS/NEKikLj7o8Cj5Z+vm7D9tU/3IrnZtG0PNz74JM/sO8jp/X1cc8Fcli0c7HRZEiiVxj6u2XqDmrtIt9m0bQ+r797JwUNHANiz7yCr794JoAYvUdPyAyLTuPHBJ19r7OMOHjrCjQ8+2aGKJHVPPPEEixYt4vjjj+erX/1qYX+OPrmLTOOZfQeDtovUMnPmTG6++WY2bdpU6J+jT+4i0zi9vy9ou+Rh07Y9LP7yI5z5uX9n8ZcfYdO2PbX/ozqdcsopvP3tb2844lgvNXeRaVxzwVz6ensmbevr7eGaC+Z2qCIp2vh9lj37DuIcvc/SygbfDmruItNYtnCQGy6ez2B/HwYM9vdxw8XzdTM1Y7ncZ9E1d5Eali0cVDPvIkXcZ/nWt77FunXrALjvvvs4/fTTG95XvfTJXURkgiLus1xxxRVs376d7du3t6Wxg5q7iMgkRd9nee655xgaGuKmm27i+uuvZ2hoiAMHDrRk3xPpsoyIyATjl+CKeir51FNPZXR0tCX7mo6au4hIhRzus+iyjIhIhtTcRUQypOYuIpIhNXcRkQypuYuIZEjNXbJR5GJPIs36+Mc/zimnnMK8efPa8uepuUsWclnsSfK1cuVKHnjggbb9eWrukoVcFnuSSOzYAF+bB1/sL/2+Y0PTuzzvvPOYOXNmC4qrjx5ikizoSzWkZXZsgHuuhEPl/+3s3116DbDg0s7VFUif3CUL+lINaZmH1xxt7OMOHSxtT4iau2RBX6ohLbN/inVfptoeKV2WkSwUvdiTdJGThkqXYqptT4iau2Qjh8WeJALnXzf5mjtAb19pexNWrFjBo48+ygsvvMDQ0BBf+tKXuPzyy5ssdmpq7tK0Tdv26BOz5GP8punDa0qXYk4aKjX2Jm+mrl+/vgXF1U/NXZoyni8fjyGO58sBNXhJ14JLk0rGVKMbqtIU5ctF4qTmLk1RvlxS4e6dLiFIs/WquUtTlC+XFMyYMYO9e/cm0+Ddnb179zJjxoyG96Fr7tKUay6YO+maOyhfLvEZGhpidHSUsbGxTpdStxkzZjA01Hj8Us1d
mqJ8uaSgt7eXM888s9NltFXdzd3MeoARYI+7L6l473jgDuBsYC9wmbs/3cI6JWLKl4vEJ+ST+1XALuDEKu9dDvze3d9sZsuBrwCXtaA+kaQo8y+xqOuGqpkNAe8DbpliyEXA7eWfNwLnm5k1X55IOrSmvMSk3rTM14HPAq9O8f4gsBvA3Q8D+4FZTVcnkhBl/iUmNZu7mS0Bnnf3rdMNq7LtmMyRma0ysxEzG0nprrVIPZT5l5jU88l9MbDUzJ4Gvge8y8zurBgzCswGMLPjgJOAFyt35O5r3X3Y3YcHBgaaKlwkNsr8S0xqNnd3X+3uQ+4+B1gOPOLuH6oYthn4aPnnS8pj0nhaQKRFtKa8xKThnLuZrQFG3H0zcCvwXTN7itIn9uUtqk8kGcr8S0ysUx+wh4eHfWRkpCN/tohIqsxsq7sP1xqnJ1QlWtdu2sn6Lbs54k6PGSvOnc31y+Z3uiyRJKi5S5Su3bSTOx/73Wuvj7i/9loNXqQ2rQopUVq/pcp3WE6zXUQmU3OXKB2Z4l7QVNtFZDI1d4lSzxSrV0y1XUQmU3OXKK04d3bQdhGZTDdUJUrjN02VlhFpjHLuIiIJUc5dmvLBdT/nZ78+ujzQ4jfN5K5PLOpgRZ2jNdolRbrmLseobOwAP/v1i3xw3c87VFHnaI12SZWauxyjsrHX2p4zrdEuqVJzF5mG1miXVKm5i0xDa7RLqtTc5RiL3zQzaHvOtEa7pErNXY5x1ycWHdPIuzUts2zhIDdcPJ/B/j4MGOzv44aL5ystI9FTzl1EJCHKuUtTisp2h+xX+XKRxqm5yzHGs93jEcDxbDfQVHMN2W9RNYh0C11zl2MUle0O2a/y5SLNUXOXYxSV7Q7Zr/LlIs1Rc5djFJXtDtmv8uUizVFzl2MUle0O2a/y5SLN0Q1VOcb4DctWJ1VC9ltUDSLdQjl3EZGEKOdesBgy2KE1xFCziLSHmnsDYshgh9YQQ80i0j66odqAGDLYoTXEULOItI+aewNiyGCH1hBDzSLSPmruDYghgx1aQww1i0j7qLk3IIYMdmgNMdQsIu2jG6oNiCGDHVpDDDWLSPvUzLmb2Qzgx8DxlP4y2OjuX6gYsxK4ERj/Svhvuvst0+1XOXcRkXCtzLm/ArzL3V8ys17gp2Z2v7s/VjHu++7+qUaKlfa4dtNO1m/ZzRF3esxYce5srl82v+mxseTnY6lDJAY1m7uXPtq/VH7ZW/7VmcdapWHXbtrJnY/97rXXR9xfe13ZtEPGxpKfj6UOkVjUdUPVzHrMbDvwPPCQu2+pMuwDZrbDzDaa2eyWVilNW79ld93bQ8bGkp+PpQ6RWNTV3N39iLu/DRgCzjGzeRVD7gHmuPsC4EfA7dX2Y2arzGzEzEbGxsaaqVsCHZni3kq17SFjY8nPx1KHSCyCopDuvg94FHhPxfa97v5K+eU64Owp/vu17j7s7sMDAwMNlCuN6jGre3vI2Fjy87HUIRKLms3dzAbMrL/8cx/wbuCJijGnTXi5FNjVyiKleSvOrX6lrNr2kLGx5OdjqUMkFvWkZU4DbjezHkp/GWxw93vNbA0w4u6bgSvNbClwGHgRWFlUwdKY8Ruh9SRgQsbGkp+PpQ6RWGg9dxGRhGg994IVlakOyZcXue+Q+aV4LJKzYwM8vAb2j8JJQ3D+dbDg0k5XJRFTc29AUZnqkHx5kfsOmV+KxyI5OzbAPVfCoXLyZ//u0mtQg5cpaeGwBhSVqQ7Jlxe575D5pXgskvPwmqONfdyhg6XtIlNQc29AUZnqkHx5kfsOmV+KxyI5+0fDtoug5t6QojLVIfnyIvcdMr8Uj0VyThoK2y6CmntDispUh+TLi9x3yPxSPBbJOf866K34y7K3r7RdZAq6odqAojLVIfnyIvcdMr8Uj0Vyxm+aKi0jAZRzFxFJiHLucowYsuuSOOXtk6Hm3iViyK5L4pS3T4puqHaJGLLrkjjl7ZOi5t4lYsiuS+KUt0+KmnuX
iCG7LolT3j4pau5dIobsuiROefuk6IZql4ghuy6JU94+Kcq5i4gkRDn3sqLy2iH7jWVdcmXXI5N7Zjz3+YXowLHIurkXldcO2W8s65Irux6Z3DPjuc8vRIeORdY3VIvKa4fsN5Z1yZVdj0zumfHc5xeiQ8ci6+ZeVF47ZL+xrEuu7Hpkcs+M5z6/EB06Flk396Ly2iH7jWVdcmXXI5N7Zjz3+YXo0LHIurkXldcO2W8s65Irux6Z3DPjuc8vRIeORdY3VIvKa4fsN5Z1yZVdj0zumfHc5xeiQ8dCOXcRkYQo514w5edFEnHv1bD1NvAjYD1w9kpYclPz+408x6/m3gDl50USce/VMHLr0dd+5OjrZhp8Ajn+rG+oFkX5eZFEbL0tbHu9Esjxq7k3QPl5kUT4kbDt9Uogx6/m3gDl50USYT1h2+uVQI5fzb0Bys+LJOLslWHb65VAjl83VBug/LxIIsZvmrY6LZNAjl85dxGRhLQs525mM4AfA8eXx2909y9UjDkeuAM4G9gLXObuTzdQd02h+fLU1jAPya7nfiwKzRGHZJ+LqqPI+UWewW5K6NxyPhbTqOeyzCvAu9z9JTPrBX5qZve7+2MTxlwO/N7d32xmy4GvAJe1utjQfHlqa5iHZNdzPxaF5ohDss9F1VHk/BLIYDcsdG45H4saat5Q9ZKXyi97y78qr+VcBNxe/nkjcL5Z62Mbofny1NYwD8mu534sCs0Rh2Sfi6qjyPklkMFuWOjccj4WNdSVljGzHjPbDjwPPOTuWyqGDAK7Adz9MLAfmFVlP6vMbMTMRsbGxoKLDc2Xp7aGeUh2PfdjUWiOOCT7XFQdRc4vgQx2w0LnlvOxqKGu5u7uR9z9bcAQcI6ZzasYUu1T+jEdyd3Xuvuwuw8PDAwEFxuaL09tDfOQ7Hrux6LQHHFI9rmoOoqcXwIZ7IaFzi3nY1FDUM7d3fcBjwLvqXhrFJgNYGbHAScBL7agvklC8+WprWEekl3P/VgUmiMOyT4XVUeR80sgg92w0LnlfCxqqCctMwAccvd9ZtYHvJvSDdOJNgMfBX4OXAI84gVkLEPz5amtYR6SXc/9WBSaIw7JPhdVR5HzSyCD3bDQueV8LGqomXM3swWUbpb2UPqkv8Hd15jZGmDE3TeX45LfBRZS+sS+3N1/M91+lXMXEQnXspy7u++g1LQrt1834eeXgb8JLVJERIqR/fIDyT24I+0R8mBLDA/BFPngTmoPacVwPhKQdXNP7sEdaY+QB1tieAimyAd3UntIK4bzkYisV4VM7sEdaY+QB1tieAimyAd3UntIK4bzkYism3tyD+5Ie4Q82BLDQzBFPriT2kNaMZyPRGTd3JN7cEfaI+TBlhgeginywZ3UHtKK4XwkIuvmntyDO9IeIQ+2xPAQTJEP7qT2kFYM5yMRWTf3ZQsHueHi+Qz292HAYH8fN1w8XzdTu92CS+H9N8NJswEr/f7+m6vfkAsZG0O9oeOLml9q+82QvqxDRCQhLXuISaTrhXyxRyxSqzmW7HosdbSAmrvIdEK+2CMWqdUcS3Y9ljpaJOtr7iJNC/lij1ikVnMs2fVY6mgRNXeR6YR8sUcsUqs5lux6LHW0iJq7yHRCvtgjFqnVHEt2PZY6WkTNXWQ6IV/sEYvUao4lux5LHS2i5i4ynSU3wfDlRz/1Wk/pdYw3JselVnMs2fVY6mgR5dxFRBKinLu0T4rZ4KJqLipfnuIxlo5Sc5fmpJgNLqrmovLlKR5j6Thdc5fmpJgNLqrmovLlKR5j6Tg1d2lOitngomouKl+e4jGWjlNzl+akmA0uquai8uUpHmPpODV3aU6K2eCiai4qX57iMZaOU3OX5qSYDS6q5qLy5SkeY+k45dxFRBJSb85dn9wlHzs2wNfmwRf7S7/v2ND+/RZVg0gg5dwlD0VlwUP2qzy6RESf3CUPRWXBQ/arPLpERM1d8lBUFjxkv8qjS0TU3CUPRWXBQ/arPLpERM1d
8lBUFjxkv8qjS0TU3CUPRWXBQ/arPLpEpGbO3cxmA3cApwKvAmvd/RsVY94J/Bvw3+VNd7v7tHeRlHMXEQnXyvXcDwOfcffHzewEYKuZPeTuv6oY9xN3X9JIsRKhFNcPD6k5xfnFQMctGTWbu7s/Czxb/vkPZrYLGAQqm7vkIsW8tvLoxdNxS0rQNXczmwMsBLZUeXuRmf3CzO43s7e2oDbplBTz2sqjF0/HLSl1P6FqZm8AfgB82t0PVLz9OPBGd3/JzC4ENgFvqbKPVcAqgDPOOKPhoqVgKea1lUcvno5bUur65G5mvZQa+13ufnfl++5+wN1fKv98H9BrZidXGbfW3YfdfXhgYKDJ0qUwKea1lUcvno5bUmo2dzMz4FZgl7tXXbvUzE4tj8PMzinvd28rC5U2SjGvrTx68XTcklLPZZnFwIeBnWa2vbzt88AZAO7+HeAS4JNmdhg4CCz3Tq0lLM0bvzmWUioipOYU5xcDHbekaD13EZGEtDLnLrFS5niye6+GrbeVvpDaekpfb9fstyCJJErNPVXKHE9279UwcuvR137k6Gs1eOlCWlsmVcocT7b1trDtIplTc0+VMseT+ZGw7SKZU3NPlTLHk1lP2HaRzKm5p0qZ48nOXhm2XSRzau6p0trhky25CYYvP/pJ3XpKr3UzVbqUcu4iIglRzr0Bm7bt4cYHn+SZfQc5vb+Pay6Yy7KFg50uq3Vyz8XnPr8Y6BgnQ829bNO2Pay+eycHD5XSFXv2HWT13TsB8mjwuefic59fDHSMk6Jr7mU3Pvjka4193MFDR7jxwSc7VFGL5Z6Lz31+MdAxToqae9kz+w4GbU9O7rn43OcXAx3jpKi5l53e3xe0PTm55+Jzn18MdIyTouZeds0Fc+nrnfzAS19vD9dcMLdDFbVY7rn43OcXAx3jpOiGatn4TdNs0zK5r8Wd+/xioGOcFOXcRUQSUm/OXZdlRFKwYwN8bR58sb/0+44NaexbOkaXZURiV2S+XNn1bOmTu0jsisyXK7ueLTV3kdgVmS9Xdj1bau4isSsyX67serbU3EViV2S+XNn1bKm5i8SuyLX79b0A2VLOXUQkIcq5i4h0MTV3EZEMqbmLiGRIzV1EJENq7iIiGVJzFxHJkJq7iEiG1NxFRDJUs7mb2Wwz+08z22VmvzSzq6qMMTO72cyeMrMdZnZWMeVKU7Rut0jXqGc998PAZ9z9cTM7AdhqZg+5+68mjHkv8Jbyr3OBb5d/l1ho3W6RrlLzk7u7P+vuj5d//gOwC6j8YtGLgDu85DGg38xOa3m10jit2y3SVYKuuZvZHGAhsKXirUFg94TXoxz7FwBmtsrMRsxsZGxsLKxSaY7W7RbpKnU3dzN7A/AD4NPufqDy7Sr/yTErkrn7WncfdvfhgYGBsEqlOVq3W6Sr1NXczayXUmO/y93vrjJkFJg94fUQ8Ezz5UnLaN1uka5ST1rGgFuBXe5+0xTDNgMfKadm3gHsd/dnW1inNEvrdot0lXrSMouBDwM7zWx7edvngTMA3P07wH3AhcBTwB+Bj7W+VGnagkvVzEW6RM3m7u4/pfo19YljHLiiVUWJiEhz9ISqiEiG1NxFRDKk5i4ikiE1dxGRDKm5i4hkSM1dRCRDau4iIhmyUkS9A3+w2Rjw24784bWdDLzQ6SIKpPmlK+e5geZXjze6e83FuTrW3GNmZiPuPtzpOoqi+aUr57mB5tdKuiwjIpIhNXcRkQypuVe3ttMFFEzzS1fOcwPNr2V0zV1EJEP65C4ikqGubu5m1mNm28zs3irvrTSzMTPbXv71t52osRlm9rSZ7SzXP1LlfTOzm83sKTPbYWZndaLORtQxt3ea2f4J5y+pr5wys34z22hmT5jZLjNbVPF+sucO6ppfsufPzOZOqHu7mR0ws09XjCn8/NXzZR05uwrYBZw4xfvfd/dPtbGeIvylu0+Vq30v8Jbyr3OBb5d/T8V0cwP4ibsvaVs1rfUN4AF3
v8TM/hR4XcX7qZ+7WvODRM+fuz8JvA1KHyCBPcAPK4YVfv669pO7mQ0B7wNu6XQtHXQRcIeXPAb0m9lpnS6q25nZicB5lL7eEnf/P3ffVzEs2XNX5/xycT7wa3evfGCz8PPXtc0d+DrwWeDVacZ8oPxPpo1mNnuacbFy4D/MbKuZrary/iCwe8Lr0fK2FNSaG8AiM/uFmd1vZm9tZ3FN+nNgDPjn8mXDW8zs9RVjUj539cwP0j1/Ey0H1lfZXvj568rmbmZLgOfdfes0w+4B5rj7AuBHwO1tKa61Frv7WZT+CXiFmZ1X8X61r09MJT5Va26PU3pM+y+AfwI2tbvAJhwHnAV8290XAv8LfK5iTMrnrp75pXz+AChfbloK/Gu1t6tsa+n568rmTulLv5ea2dPA94B3mdmdEwe4+153f6X8ch1wdntLbJ67P1P+/XlK1/zOqRgyCkz8F8kQ8Ex7qmtOrbm5+wF3f6n8831Ar5md3PZCGzMKjLr7lvLrjZSaYeWYJM8ddcwv8fM37r3A4+7+P1XeK/z8dWVzd/fV7j7k7nMo/bPpEXf/0MQxFde/llK68ZoMM3u9mZ0w/jPw18B/VQzbDHykfOf+HcB+d3+2zaUGq2duZnaqmVn553Mo/W99b7trbYS7PwfsNrO55U3nA7+qGJbkuYP65pfy+ZtgBdUvyUAbzl+3p2UmMbM1wIi7bwauNLOlwGHgRWBlJ2trwJ8BPyz//+M44F/c/QEz+3sAd/8OcB9wIfAU8EfgYx2qNVQ9c7sE+KSZHQYOAss9rSf2/gG4q/xP+98AH8vk3I2rNb+kz5+ZvQ74K+DvJmxr6/nTE6oiIhnqyssyIiK5U3MXEcmQmruISIbU3EVEMqTmLiKSITV3EZEMqbmLiGRIzV1EJEP/D+1KgcwTy4s9AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.scatter(X[:50,0],X[:50,1], label='-1')\n",
    "plt.scatter(X[50:,0],X[50:,1], label='1')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "----\n",
    "##### SMO algorithm\n",
    "Algorithm 7.5, p. 130 (Li Hang, *Statistical Learning Methods*)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 155,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SVM:\n",
    "    def __init__(self, max_iter=100, epsilon=0.001, C=1.0, kernel='linear'):\n",
    "        self.max_iter = max_iter\n",
    "        self.kernel = kernel\n",
    "        self.epsilon = epsilon  # tolerance used when checking the KKT conditions\n",
    "        self.C = C\n",
    "    \n",
    "    def _init_parameters(self, X, y):\n",
    "        '''\n",
    "        Initialize the model parameters.\n",
    "        '''\n",
    "        self.X = X\n",
    "        self.y = y\n",
    "\n",
    "        self.b = 0.0\n",
    "        self.M, self.N = X.shape\n",
    "        # Algorithm 7.5 starts from alpha = 0\n",
    "        self.alpha = np.zeros(self.M)\n",
    "        # Precompute the Gram matrix K[i, j] = K(x_i, x_j) once\n",
    "        self.K = np.array([[self._kernel(xi, xj) for xj in X] for xi in X])\n",
    "        self.E = np.array([self._E(i) for i in range(self.M)])\n",
    "\n",
    "    def _kernel(self, x1, x2):\n",
    "        # Kernel function; only the linear kernel is implemented\n",
    "        if self.kernel == 'linear':\n",
    "            return np.dot(x1, x2)\n",
    "        raise ValueError('unsupported kernel: {}'.format(self.kernel))\n",
    "    \n",
    "    def _gx(self, i):\n",
    "        # g(x_i) = sum_j alpha_j y_j K(x_j, x_i) + b, Eq. 7.104\n",
    "        return np.dot(self.alpha * self.y, self.K[:, i]) + self.b\n",
    "    \n",
    "    def _E(self, i):\n",
    "        # E_i = g(x_i) - y_i, Eq. 7.105\n",
    "        return self._gx(i) - self.y[i]\n",
    "    \n",
    "    def _KKT(self, i):\n",
    "        # KKT conditions (p. 130), checked up to the tolerance epsilon\n",
    "        ygx = self.y[i] * self._gx(i)\n",
    "        if self.alpha[i] == 0:\n",
    "            return ygx >= 1 - self.epsilon\n",
    "        elif 0 < self.alpha[i] < self.C:\n",
    "            return abs(ygx - 1) <= self.epsilon\n",
    "        else:\n",
    "            return ygx <= 1 + self.epsilon\n",
    "        \n",
    "    def _init_alpha(self):\n",
    "        # Select the two variables as in Section 7.4.2.\n",
    "        # Outer loop: first check the samples with 0 < alpha < C for KKT violations,\n",
    "        index_list = [i for i in range(self.M) if 0 < self.alpha[i] < self.C]\n",
    "        # then the rest of the training set\n",
    "        other_list = [i for i in range(self.M) if i not in index_list]\n",
    "        index_list.extend(other_list)\n",
    "        \n",
    "        for i in index_list:\n",
    "            if self._KKT(i):\n",
    "                continue\n",
    "            E1 = self.E[i]\n",
    "            # Inner loop: maximize |E1 - E2|. If E1 >= 0 pick the smallest E2,\n",
    "            # otherwise the largest; never pick i itself\n",
    "            E = self.E.copy()\n",
    "            if E1 >= 0:\n",
    "                E[i] = np.inf\n",
    "                j = np.argmin(E)\n",
    "            else:\n",
    "                E[i] = -np.inf\n",
    "                j = np.argmax(E)\n",
    "            return i, j\n",
    "        # Every sample satisfies the KKT conditions: training has converged\n",
    "        return None\n",
    "        \n",
    "    def _clip(self, alpha, L, H):\n",
    "        # Clip the unconstrained solution to [L, H], Eq. 7.108\n",
    "        if alpha > H:\n",
    "            return H\n",
    "        elif alpha < L:\n",
    "            return L\n",
    "        else:\n",
    "            return alpha\n",
    "        \n",
    "    def fit(self, X, y):\n",
    "        self._init_parameters(X, y)\n",
    "        \n",
    "        for _iter in range(self.max_iter):\n",
    "            pair = self._init_alpha()\n",
    "            if pair is None:\n",
    "                break\n",
    "            i1, i2 = pair\n",
    "            \n",
    "            # Bounds L and H for alpha2_new, p. 126\n",
    "            if self.y[i1] == self.y[i2]:\n",
    "                L = max(0, self.alpha[i2] + self.alpha[i1] - self.C)\n",
    "                H = min(self.C, self.alpha[i2] + self.alpha[i1])\n",
    "            else:\n",
    "                L = max(0, self.alpha[i2] - self.alpha[i1])\n",
    "                H = min(self.C, self.C + self.alpha[i2] - self.alpha[i1])\n",
    "                \n",
    "            E1 = self.E[i1]\n",
    "            E2 = self.E[i2]\n",
    "            \n",
    "            # eta = K11 + K22 - 2K12, Eq. 7.107\n",
    "            eta = self.K[i1, i1] + self.K[i2, i2] - 2 * self.K[i1, i2]\n",
    "            \n",
    "            # Eq. 7.106; the small constant avoids division by zero when eta == 0 (duplicate points)\n",
    "            alpha2_new_unc = self.alpha[i2] + self.y[i2] * (E1 - E2) / (eta + 1e-4)\n",
    "            \n",
    "            alpha2_new = self._clip(alpha2_new_unc, L, H)  # Eq. 7.108\n",
    "            \n",
    "            alpha1_new = self.alpha[i1] + self.y[i1] * self.y[i2] * (self.alpha[i2] - alpha2_new)  # Eq. 7.109\n",
    "            \n",
    "            b1_new = (-E1 - self.y[i1] * self.K[i1, i1] * (alpha1_new - self.alpha[i1])\n",
    "                      - self.y[i2] * self.K[i2, i1] * (alpha2_new - self.alpha[i2]) + self.b)  # Eq. 7.115\n",
    "            \n",
    "            b2_new = (-E2 - self.y[i1] * self.K[i1, i2] * (alpha1_new - self.alpha[i1])\n",
    "                      - self.y[i2] * self.K[i2, i2] * (alpha2_new - self.alpha[i2]) + self.b)  # Eq. 7.116\n",
    "            \n",
    "            # Choice of b, p. 130: b1 if alpha1 is strictly inside (0, C), else b2 if alpha2 is, else the midpoint\n",
    "            if 0 < alpha1_new < self.C:\n",
    "                b_new = b1_new\n",
    "            elif 0 < alpha2_new < self.C:\n",
    "                b_new = b2_new\n",
    "            else:\n",
    "                b_new = (b1_new + b2_new) / 2\n",
    "                \n",
    "            # update parameters\n",
    "            self.alpha[i1] = alpha1_new\n",
    "            self.alpha[i2] = alpha2_new\n",
    "            self.b = b_new\n",
    "            \n",
    "            # Every E_i depends on all alphas and on b, so refresh the whole cache\n",
    "            self.E = np.array([self._E(i) for i in range(self.M)])\n",
    "            \n",
    "        return 'Done.'\n",
    "    \n",
    "    def predict(self, data):\n",
    "        # sign(sum_i alpha_i y_i K(x_i, x) + b)\n",
    "        r = self.b\n",
    "        for i in range(self.M):\n",
    "            r += self.alpha[i] * self.y[i] * self._kernel(data, self.X[i])\n",
    "            \n",
    "        return 1 if r > 0 else -1\n",
    "        \n",
    "    def score(self, X_test, y_test):\n",
    "        right_item = 0\n",
    "        for i in range(len(X_test)):\n",
    "            res = self.predict(X_test[i])\n",
    "            if res == y_test[i]:\n",
    "                right_item += 1\n",
    "        return right_item / len(X_test)\n",
    "    \n",
    "    def _weight(self):\n",
    "        # w = sum_i alpha_i y_i x_i; only meaningful for the linear kernel\n",
    "        yx = self.y.reshape(-1, 1) * self.X\n",
    "        self.w = np.dot(yx.T, self.alpha)\n",
    "        return self.w, self.b\n",
    "\n",
    "\n",
    "#https://blog.csdn.net/wds2006sdo/article/details/53156589\n",
    "#https://github.com/fengdu78/lihang-code/blob/master/code/%E7%AC%AC7%E7%AB%A0%20%E6%94%AF%E6%8C%81%E5%90%91%E9%87%8F%E6%9C%BA(SVM)/support-vector-machine.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 156,
   "metadata": {},
   "outputs": [],
   "source": [
    "svm = SVM(max_iter=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 157,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Done.'"
      ]
     },
     "execution_count": 157,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "svm.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 158,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 158,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "svm.score(X_test, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 159,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([ 3.6, -5.7]), -3.8699999999999815)"
      ]
     },
     "execution_count": 159,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "svm._weight() #array([ 3.6, -5.7])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## sklearn.svm.SVC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 169,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/max/anaconda2/envs/pytorch/lib/python3.6/site-packages/sklearn/svm/base.py:196: FutureWarning: The default value of gamma will change from 'auto' to 'scale' in version 0.22 to account better for unscaled features. Set gamma explicitly to 'auto' or 'scale' to avoid this warning.\n",
      "  \"avoid this warning.\", FutureWarning)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,\n",
       "  decision_function_shape='ovr', degree=3, gamma='auto_deprecated',\n",
       "  kernel='rbf', max_iter=-1, probability=False, random_state=None,\n",
       "  shrinking=True, tol=0.001, verbose=False)"
      ]
     },
     "execution_count": 169,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.svm import SVC\n",
    "clf = SVC()\n",
    "clf.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 170,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 170,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf.score(X_test, y_test)"
   ]
  },
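  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A sketch connecting `SVC` back to the dual formulation above. For a linear kernel, scikit-learn stores $y_i\\alpha_i$ of each support vector in `dual_coef_`, so $w=\\sum_i\\alpha_iy_ix_i$ can be rebuilt as `dual_coef_ @ support_vectors_` and compared with `coef_`. The cell reloads the same two-class iris subset as `create_data()` so it runs on its own. The linear kernel and `C=1.0` are choices made for this illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.svm import SVC\n",
    "\n",
    "# Same subset as create_data(): first 100 samples, sepal length/width, labels -1/+1\n",
    "iris = load_iris()\n",
    "X_all = iris.data[:100, :2]\n",
    "y_all = np.where(iris.target[:100] == 0, -1, 1)\n",
    "\n",
    "clf_linear = SVC(kernel='linear', C=1.0).fit(X_all, y_all)\n",
    "\n",
    "# dual_coef_ holds y_i * alpha_i for each support vector, so w = sum_i alpha_i y_i x_i\n",
    "w = clf_linear.dual_coef_ @ clf_linear.support_vectors_\n",
    "print('w from dual:', w)\n",
    "print('coef_:', clf_linear.coef_, 'intercept_:', clf_linear.intercept_)\n",
    "print('support vectors per class:', clf_linear.n_support_)\n",
    "print(np.allclose(w, clf_linear.coef_))"
   ]
  },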
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
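    "### sklearn.svm.SVC\n",
    "\n",
    "*(C=1.0, kernel='rbf', degree=3, gamma='auto', coef0=0.0, shrinking=True, probability=False, tol=0.001, cache_size=200, class_weight=None, verbose=False, max_iter=-1, decision_function_shape='ovr', random_state=None)*\n",
    "\n",
    "Parameters:\n",
    "\n",
    "- C: penalty parameter C of C-SVC, default 1.0.\n",
    "\n",
    "  A larger C penalizes the slack variables more heavily and pushes them toward 0. Misclassification then costs more, so the model tries to classify every training sample correctly: training accuracy is high, but generalization may be weak. A smaller C lowers the penalty and tolerates some misclassified points as noise, which usually generalizes better.\n",
    "\n",
    "- kernel: kernel function, default 'rbf'; one of 'linear', 'poly', 'rbf', 'sigmoid', 'precomputed'\n",
    "\n",
    "  - linear: `u'*v`\n",
    "\n",
    "  - polynomial: `(gamma*u'*v + coef0)^degree`\n",
    "\n",
    "  - RBF: `exp(-gamma*|u-v|^2)`\n",
    "\n",
    "  - sigmoid: `tanh(gamma*u'*v + coef0)`\n",
    "\n",
    "- degree: degree of the 'poly' kernel, default 3; ignored by the other kernels.\n",
    "\n",
    "- gamma: kernel coefficient for 'rbf', 'poly' and 'sigmoid'. Default 'auto', which uses 1/n_features. As the FutureWarning above notes, the default changes to 'scale' in version 0.22.\n",
    "\n",
    "- coef0: independent term in the kernel function; only used by 'poly' and 'sigmoid'.\n",
    "\n",
    "- probability: whether to enable probability estimates, default False.\n",
    "\n",
    "- shrinking: whether to use the shrinking heuristic, default True.\n",
    "\n",
    "- tol: tolerance for the stopping criterion, default 1e-3.\n",
    "\n",
    "- cache_size: size of the kernel cache in MB, default 200.\n",
    "\n",
    "- class_weight: per-class weights as a dict; the C of class i is set to `class_weight[i]*C`.\n",
    "\n",
    "- verbose: enable verbose output.\n",
    "\n",
    "- max_iter: hard limit on solver iterations; -1 means no limit.\n",
    "\n",
    "- decision_function_shape: 'ovo' or 'ovr', default 'ovr' (as in the fitted model printed above).\n",
    "\n",
    "- random_state: seed (an int) for the random number generator used to shuffle the data for probability estimates.\n",
    "\n",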
    "\n",
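    "The kernel formulas listed above can be checked numerically against scikit-learn's `sklearn.metrics.pairwise` kernel functions. A minimal sketch; the vectors and the values of `gamma`, `coef0` and `degree` are arbitrary choices for illustration.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics.pairwise import linear_kernel, polynomial_kernel, rbf_kernel, sigmoid_kernel\n",
    "\n",
    "# Two illustrative vectors (as 1 x 2 matrices) and arbitrary kernel parameters\n",
    "u = np.array([[1.0, 2.0]])\n",
    "v = np.array([[0.5, -1.0]])\n",
    "gamma, coef0, degree = 0.5, 1.0, 3\n",
    "\n",
    "dot = float(np.dot(u[0], v[0]))  # u'v\n",
    "manual = {\n",
    "    'linear': dot,\n",
    "    'poly': (gamma * dot + coef0) ** degree,\n",
    "    'rbf': np.exp(-gamma * np.sum((u - v) ** 2)),\n",
    "    'sigmoid': np.tanh(gamma * dot + coef0),\n",
    "}\n",
    "library = {\n",
    "    'linear': linear_kernel(u, v)[0, 0],\n",
    "    'poly': polynomial_kernel(u, v, degree=degree, gamma=gamma, coef0=coef0)[0, 0],\n",
    "    'rbf': rbf_kernel(u, v, gamma=gamma)[0, 0],\n",
    "    'sigmoid': sigmoid_kernel(u, v, gamma=gamma, coef0=coef0)[0, 0],\n",
    "}\n",
    "for name in manual:\n",
    "    print(name, manual[name], library[name])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [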
    "The main parameters to tune are C, kernel, degree, gamma and coef0."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
